# -*- coding: utf-8 -*-  
'''
Created on 2021年10月2日
@author: luoyi
'''
import os
import sys
#    Locate the project root (a directory named 'NLU') and make it importable,
#    so the project-local `utils.*` / `data.*` imports below resolve.
def find_project_root(start, marker='NLU'):
    """Return the outermost ancestor of *start* (inclusive) whose name is *marker*.

    Whole path components are compared, not substrings: a plain
    ``start.split(marker)[0]`` turns ``/home/NLU_user/NLU/data`` into
    ``/home/NLU``, because the marker also occurs inside ``NLU_user``.

    When no path component equals *marker*, the legacy substring heuristic is
    used unchanged, so layouts that used to resolve still do.

    :param start: absolute directory path to search upward from.
    :param marker: directory name identifying the project root.
    :return: the project root path, without a trailing separator.
    """
    found = None
    head = os.path.normpath(start)
    while True:
        parent, name = os.path.split(head)
        if name == marker:
            # Keep walking upward: the outermost match wins, matching the
            # "first occurrence" preference of the original substring split.
            found = head
        if not name or parent == head:
            break
        head = parent
    if found is None:
        found = start.split(marker)[0] + marker
    return found


ROOT_PATH = find_project_root(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(ROOT_PATH)

import utils.conf as conf
from data.dataset_baidu_nlu_generator import QuestionIterator
from data.dataset_baidu_nlu_tfrecord import TFRecordWriter
from utils.dictionary import Dictionaries
from utils.relationship import Relationship


#    Load the character dictionary from its pickle, then the relation /
#    entity-tag dictionary from file, reporting their sizes.
Dictionaries.instance().load_from_pkl()
print('初始化字典, 总字数：',
      Dictionaries.instance().size())

Relationship.instance().load_from_file()
print('初始化关系库，关系实体标注。总关系数:',
      Relationship.instance().rel_size(),
      ' 实体标注数：',
      Relationship.instance().sot_size())
#    NOTE(review): dumps the internal tag -> id mapping; looks like leftover
#    debug output reading a private attribute.
print(Relationship.instance()._sot2id)


#    One question iterator is shared by the writers of both splits.
qi = QuestionIterator()

#    (label printed with the count, source question-path getter,
#     TFRecord output-path getter) for each split, in write order:
#    training set first (last recorded count: 1452723),
#    then validation set (last recorded count: 181793).
#    Getters are stored uncalled so each is invoked only after its writer
#    has been constructed.
splits = (
    ('总训练集样本数量：',
     conf.DATASET_BAIDU.get_question_train_data_path,
     conf.DATASET_BAIDU.get_tfrecord_train_data_path),
    ('总验证集样本数量：',
     conf.DATASET_BAIDU.get_question_val_data_path,
     conf.DATASET_BAIDU.get_tfrecord_val_data_path),
)
for label, question_path_of, tfrecord_path_of in splits:
    #    A fresh writer per split; write() returns the number of samples.
    tfw = TFRecordWriter(q_iter=qi)
    c = tfw.write(fpath=question_path_of(),
                  tfrecord_fpath=tfrecord_path_of())
    print(label, c)